import lightning as L
from typing import Optional
from torch.utils.data import DataLoader
from datasets import load_from_disk
from transformers import DataCollatorForLanguageModeling, AutoTokenizer
from .utils import registry


class CausalLMDataModule(L.LightningDataModule):
    """Lightning data module for causal (autoregressive) language modeling.

    Loads a pre-tokenized HuggingFace dataset from disk, holds out a
    fraction of its ``train`` split for validation, and serves batches
    collated for causal LM (``mlm=False`` -> the collator mirrors
    ``input_ids`` into ``labels``).
    """

    def __init__(self,
                 dataset_path: str,
                 batch_size: int,
                 tokenizer_name_or_path: str,
                 num_workers: int = 4,
                 val_size: float = 0.01,
                 seed: int = 42) -> None:
        """Store hyperparameters and build tokenizer + collator.

        Args:
            dataset_path: Path to a dataset saved via ``save_to_disk``;
                expected to contain a ``train`` split (see ``setup``).
            batch_size: Batch size for both train and val dataloaders.
            tokenizer_name_or_path: HF hub id or local path for the tokenizer.
            num_workers: DataLoader worker process count.
            val_size: Fraction of ``train`` held out for validation.
            seed: Seed for the train/val split, so the split is identical
                across re-runs and across DDP ranks (each rank calls
                ``setup`` independently).
        """
        super().__init__()
        self.save_hyperparameters()
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
        # Many causal-LM tokenizers (e.g. GPT-2) ship without a pad token;
        # reuse EOS so the collator can pad. Only fill it in when missing,
        # so a tokenizer with a genuine pad token is not clobbered.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        # mlm=False selects causal-LM collation (labels == input_ids).
        self.collate_fn = DataCollatorForLanguageModeling(mlm=False, tokenizer=self.tokenizer)

    def prepare_data(self) -> None:
        # Dataset is already materialized on disk; nothing to download.
        pass

    def setup(self, stage: Optional[str] = 'fit') -> None:
        """Load the on-disk dataset and create the train/val split once."""
        # Lightning may invoke setup() once per stage ('fit', 'validate', ...).
        # Guard against re-splitting, which would draw a *new* random split
        # and leak previously-held-out val examples into train.
        if hasattr(self, 'ds'):
            return
        ds = load_from_disk(self.hparams.dataset_path)
        ds.set_format('torch')
        # Seeded split: deterministic across runs and DDP processes.
        self.ds = ds['train'].train_test_split(
            test_size=self.hparams.val_size, seed=self.hparams.seed)
        self.ds['val'] = self.ds.pop('test')

    def train_dataloader(self) -> DataLoader:
        # shuffle=True: without it every epoch would iterate the training
        # set in the same on-disk order.
        return DataLoader(dataset=self.ds['train'],
                          collate_fn=self.collate_fn,
                          batch_size=self.hparams.batch_size,
                          num_workers=self.hparams.num_workers,
                          shuffle=True)

    def val_dataloader(self) -> DataLoader:
        # Validation keeps the default shuffle=False for stable metrics.
        return DataLoader(dataset=self.ds['val'],
                          collate_fn=self.collate_fn,
                          batch_size=self.hparams.batch_size,
                          num_workers=self.hparams.num_workers)
    

@registry.datamodules('language_modeling')
def build_lm_datamodule(dataset_path: str, batch_size: int, tokenizer_name_or_path: str, val_size: float = 0.01, num_workers: int = 4):
    """Registry factory: construct the causal-LM data module from plain args."""
    kwargs = dict(
        dataset_path=dataset_path,
        batch_size=batch_size,
        tokenizer_name_or_path=tokenizer_name_or_path,
        val_size=val_size,
        num_workers=num_workers,
    )
    return CausalLMDataModule(**kwargs)